import torch
import numpy as np
import torch.nn as nn




class MIL_softmax_loss(nn.Module):
    '''Moving-average surrogate loss for MIL with softmax pooling.

    For every bag index in ``[0, data_length)`` a moving average ``self.s`` of
    the detached ``exps`` input is kept, and ``tau * log(s)`` is used as the
    pooled bag score. ``forward`` returns ``gw_p + gw_n + gw_s + ga + gb``:

    * the ``gw_*`` terms use detached ``a``/``b``/``alpha`` and therefore only
      send gradient to ``exps``;
    * ``ga``/``gb`` are square losses of the pooled scores around ``a``/``b``
      and only send gradient to ``self.a``/``self.b`` (``self.s`` holds
      detached values).

    NOTE(review): ``exps`` is presumably the per-bag sum of exp(score / tau);
    confirm with the caller. ``self.a``/``self.b`` are plain tensors with
    ``requires_grad=True``, not ``nn.Parameter``s, so they are not returned by
    ``self.parameters()`` and must be handed to the optimizer explicitly.
    ``self.alpha`` is never changed in this class; presumably the optimizer
    updates it.
    '''

    def __init__(self, data_length, threshold=0.5, tau=0.1, gamma=0.9, eta=1.0, device=None):
        '''
        :param data_length: number of bags; size of the moving-average table ``self.s``.
        :param threshold: margin value; stored as ``self.threshold`` but not used by ``forward``.
        :param tau: temperature; the pooled score is ``tau * log(s)``.
        :param gamma: weight of the newest ``exps`` value in the moving average.
        :param eta: step size rescaled by ``update_smoothing``; not used by ``forward``.
        :param device: device for the state tensors; falsy means CUDA when available, else CPU.
        '''
        super(MIL_softmax_loss, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.gamma = gamma
        self.eta = eta
        self.tau = tau
        self.data_length = data_length
        # Per-bag moving average of exps, shape (data_length, 1); detached values only.
        self.s = torch.tensor([0.0]*data_length).view(-1, 1).to(self.device)
        self.a = torch.zeros(1, dtype=torch.float32, requires_grad=True, device=self.device)
        self.b = torch.zeros(1, dtype=torch.float32, requires_grad=True, device=self.device)
        self.alpha = torch.zeros(1, dtype=torch.float32, requires_grad=False, device=self.device)
        self.threshold = threshold

    def update_smoothing(self, decay_factor):
        '''Divide ``gamma`` and ``eta`` by ``decay_factor``.'''
        self.gamma = self.gamma/decay_factor
        self.eta = self.eta/decay_factor

    def forward(self, exps, y_true, ids):
        '''Update the moving averages for ``ids`` and return the surrogate loss.

        :param exps: positive per-bag values with B elements, e.g. shape (B,) or (B, 1).
        :param y_true: bag labels with B elements; 1 = positive, 0 = negative
            (any other value is ignored).
        :param ids: B indices in ``[0, data_length)`` of the bags in this batch.
        :return: scalar tensor. If the batch has no positive (or no negative)
            bags the means run over empty tensors and the value is NaN.
        '''
        # Normalise layouts so (B,) and (B, 1) inputs both line up with the
        # (data_length, 1) state table instead of broadcasting to (B, B).
        exps = exps.reshape(-1, 1)
        y_true = y_true.reshape(-1, 1)
        ids = ids.reshape(-1).to(self.device)
        self.s[ids] = (1-self.gamma) * self.s[ids] + self.gamma * exps.detach()
        vs = self.s[ids]
        ids_p = (y_true == 1)
        ids_n = (y_true == 0)
        s_p = vs[ids_p]
        s_n = vs[ids_n]
        logs_p = self.tau*torch.log(s_p)
        logs_n = self.tau*torch.log(s_n)
        # With s frozen, the gradient of exps / s is d(exps) / s.
        gw_ins_p = exps[ids_p]/s_p
        gw_ins_n = exps[ids_n]/s_n
        gw_p = torch.mean(2*self.tau*(logs_p-self.a.detach())*gw_ins_p)
        gw_n = torch.mean(2*self.tau*(logs_n-self.b.detach())*gw_ins_n)
        gw_s = self.alpha.detach() * self.tau * (torch.mean(gw_ins_n) - torch.mean(gw_ins_p))
        ga = torch.mean((logs_p - self.a)**2)
        gb = torch.mean((logs_n - self.b)**2)
        loss = gw_p + gw_n + gw_s + ga + gb
        return loss



class MIL_attention_loss(nn.Module):
    '''Moving-average surrogate loss for MIL with attention pooling.

    For every bag index in ``[0, data_length)`` two moving averages are kept:
    ``self.sn`` of the detached ``sn`` input and ``self.sd`` of the detached
    ``sd`` input. The pooled bag score is ``sigmoid(sn_avg / sd_avg)``.
    ``forward`` returns ``gw_p + gw_n + gw_s + ga + gb``:

    * the ``gw_*`` terms use detached ``a``/``b``/``alpha`` and only send
      gradient to ``sn``/``sd``. The chain rule of sigmoid(numerator /
      denominator) is applied with the moving averages held fixed.
    * ``ga``/``gb`` are square losses of the pooled scores around ``a``/``b``
      and only send gradient to ``self.a``/``self.b``.

    NOTE(review): ``sn``/``sd`` are presumably the numerator and denominator of
    an attention-weighted average of instance scores; confirm with the caller.
    ``self.a``/``self.b`` are plain tensors, not ``nn.Parameter``s, so they are
    not returned by ``self.parameters()``. ``self.alpha`` is never changed in
    this class; presumably the optimizer updates it.
    '''

    def __init__(self, data_length, threshold=0.5, gamma_1=0.9, gamma_2=0.9, eta=1e0, device=None):
        '''
        :param data_length: number of bags; size of the moving-average tables.
        :param threshold: margin value; stored as ``self.threshold`` but not used by ``forward``.
        :param gamma_1: weight of the newest ``sn`` value in its moving average.
        :param gamma_2: weight of the newest ``sd`` value in its moving average.
        :param eta: step size rescaled by ``update_smoothing``; not used by ``forward``.
        :param device: device for the state tensors; falsy means CUDA when available, else CPU.
        '''
        super(MIL_attention_loss, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.gamma_1 = gamma_1
        self.gamma_2 = gamma_2
        self.eta = eta
        self.data_length = data_length
        # Per-bag moving averages, shape (data_length, 1); detached values only.
        self.sn = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
        self.sd = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
        self.a = torch.zeros(1, dtype=torch.float32, requires_grad=True, device=self.device)
        self.b = torch.zeros(1, dtype=torch.float32, requires_grad=True, device=self.device)
        self.alpha = torch.zeros(1, dtype=torch.float32, requires_grad=False, device=self.device)
        self.threshold = threshold

    def update_smoothing(self, decay_factor):
        '''Divide ``gamma_1``, ``gamma_2`` and ``eta`` by ``decay_factor``.'''
        self.gamma_1 = self.gamma_1/decay_factor
        self.gamma_2 = self.gamma_2/decay_factor
        self.eta = self.eta/decay_factor

    def forward(self, sn, sd, y_true, ids):
        '''Update the moving averages for ``ids`` and return the surrogate loss.

        :param sn: per-bag numerator values with B elements, e.g. shape (B,) or (B, 1).
        :param sd: per-bag denominator values with B elements.
        :param y_true: bag labels with B elements; 1 = positive, 0 = negative
            (any other value is ignored).
        :param ids: B indices in ``[0, data_length)`` of the bags in this batch.
        :return: scalar tensor. If the batch has no positive (or no negative)
            bags the means run over empty tensors and the value is NaN.
        '''
        # Normalise layouts so (B,) and (B, 1) inputs both line up with the
        # (data_length, 1) state tables instead of broadcasting to (B, B).
        sn = sn.reshape(-1, 1)
        sd = sd.reshape(-1, 1)
        y_true = y_true.reshape(-1, 1)
        ids = ids.reshape(-1).to(self.device)
        self.sn[ids] = (1-self.gamma_1) * self.sn[ids] + self.gamma_1 * sn.detach()
        self.sd[ids] = (1-self.gamma_2) * self.sd[ids] + self.gamma_2 * sd.detach()
        vsn = self.sn[ids]
        # Clamp keeps the ratio finite if the denominator average reaches 0.
        vsd = torch.clamp(self.sd[ids], min=1e-8)
        snd = torch.sigmoid(vsn / vsd)
        # Derivative of the sigmoid at the pooled score.
        gsnd = snd * (1-snd)
        ids_p = (y_true == 1)
        ids_n = (y_true == 0)
        snd_p = snd[ids_p]
        snd_n = snd[ids_n]
        gsnd_p = gsnd[ids_p]
        gsnd_n = gsnd[ids_n]
        # Quotient rule for sn / sd with the averages frozen: the gradient flows through sn and sd.
        gw_att_p = gsnd_p*(1/vsd[ids_p]*sn[ids_p] - vsn[ids_p]/(vsd[ids_p]**2)*sd[ids_p])
        gw_att_n = gsnd_n*(1/vsd[ids_n]*sn[ids_n] - vsn[ids_n]/(vsd[ids_n]**2)*sd[ids_n])
        gw_p = torch.mean(2*(snd_p-self.a.detach())*gw_att_p)
        gw_n = torch.mean(2*(snd_n-self.b.detach())*gw_att_n)
        gw_s = self.alpha.detach() * (torch.mean(gw_att_n) - torch.mean(gw_att_p))
        ga = torch.mean((snd_p - self.a)**2)
        gb = torch.mean((snd_n - self.b)**2)
        loss = gw_p + gw_n + gw_s + ga + gb
        return loss



class TPMIL_softmax_loss(nn.Module):
    '''Pairwise squared-hinge MIL loss with softmax pooling and adaptive thresholds.

    A per-bag moving average ``v`` of the detached ``exps`` is kept, and each
    bag is scored as ``tau * log(v)``. For every (positive, negative) pair in
    the batch the squared hinge ``max(0, threshold - (score_p - score_n)) ** 2``
    is formed. Pairs are selected where that loss exceeds the positive's
    threshold ``s1`` (indicator ``p1``). Positives are selected where ``u``
    exceeds the global threshold ``s2`` (indicator ``p2``).

    The returned value is a surrogate: its gradient w.r.t. ``exps`` is the
    selected, ``1 / rate``-scaled pairwise hinge gradient (with ``v`` held
    fixed). Its forward value is not the hinge loss itself. After it is
    computed, ``u``, ``s1`` and ``s2`` are updated in place.

    NOTE(review): the ``s + mean((loss - s)_+) / rate`` form looks like a
    CVaR-style estimate of the top-``rate`` fraction, presumably for a two-way
    partial AUC objective (per the "TP" prefix); confirm against the reference.
    '''

    def __init__(self, data_length, threshold=0.5, tau=0.1, gamma=0.9, eta=1e-1, rate=0.5, momentum=0.0, device=None):
        '''
        :param data_length: number of bags; size of the per-bag state tables.
        :param threshold: margin of the squared hinge on pooled-score differences.
        :param tau: temperature; the pooled score is ``tau * log(v)``.
        :param gamma: moving-average weight, used as ``gamma_1`` (for ``v``) and ``gamma_2`` (for ``u``).
        :param eta: step size, used as ``eta_1`` (for ``s1``) and ``eta_2`` (for ``s2``).
        :param rate: fraction appearing in the ``/ rate`` scalings.
        :param momentum: if > 0, adds ``momentum * (new - previous)`` to the ``v`` and ``u`` updates.
        :param device: device for the state tensors; falsy means CUDA when available, else CPU.
        '''
        super(TPMIL_softmax_loss, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.gamma_1 = gamma
        self.eta_1 = eta
        self.gamma_2 = gamma
        self.eta_2 = eta
        self.data_length = data_length
        self.rate = rate
        self.tau = tau
        # Per-bag moving average of exps, shape (data_length, 1).
        self.v = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
        # Per-bag moving average of comp_u; only rows of positive bags are updated.
        self.u = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
        # Per-bag threshold on pairwise losses; only rows of positive bags are updated.
        self.s1 = torch.tensor([0.0]*data_length).view(-1, 1).to(self.device)
        # Single threshold on u, shape (1, 1).
        self.s2 = torch.tensor([0.0]).view(-1, 1).to(self.device)
        self.threshold = threshold
        self.momentum = momentum
        if momentum > 0.0:
            # Previous exps / comp_u values, used to form the momentum correction.
            self.exps = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
            self.comp_u = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)

    def update_smoothing(self, decay_factor):
        '''Divide both moving-average weights and both step sizes by ``decay_factor``.'''
        self.gamma_1 = self.gamma_1/decay_factor
        self.eta_1 = self.eta_1/decay_factor
        self.gamma_2 = self.gamma_2/decay_factor
        self.eta_2 = self.eta_2/decay_factor

    def forward(self, exps, y_true, ids):
        '''Update the per-bag state for ``ids`` and return the surrogate loss.

        :param exps: positive per-bag values with B elements, e.g. shape (B,) or (B, 1).
        :param y_true: bag labels with B elements; 1 = positive, 0 = negative
            (any other value is ignored).
        :param ids: B indices in ``[0, data_length)`` of the bags in this batch.
        :return: scalar tensor. When the batch lacks positives or negatives
            there are no pairs: only ``v`` is updated, and a zero loss still
            connected to ``exps`` is returned. Previously this case wrote NaN
            into ``s2`` (no positives) or into ``u``/``s1`` (no negatives),
            permanently.
        '''
        # Normalise layouts so (B,) and (B, 1) inputs both line up with the
        # (data_length, 1) state tables instead of broadcasting to (B, B).
        exps = exps.reshape(-1, 1)
        ids = ids.reshape(-1).to(self.device)
        momentum_exps = 0.0
        if self.momentum > 0:
            momentum_exps = self.momentum * (exps.detach() - self.exps[ids])
            self.exps[ids] = exps.detach()
        self.v[ids] = (1-self.gamma_1) * self.v[ids] + self.gamma_1 * exps.detach() + momentum_exps
        vv = self.v[ids]
        # reshape(-1) rather than squeeze(): squeeze makes a 0-d mask when B == 1.
        labels = y_true.reshape(-1)
        ids_p = (labels == 1)
        ids_n = (labels == 0)
        if not ids_p.any() or not ids_n.any():
            # No (positive, negative) pairs: means below would be over empty
            # tensors and would poison u / s1 / s2 with NaN.
            return exps.sum() * 0.0
        pos_ids = ids[ids_p]
        v_p = vv[ids_p].view(-1, 1)    # (P, 1)
        v_n = vv[ids_n].view(1, -1)    # (1, N)
        log_p = self.tau * torch.log(v_p)
        log_n = self.tau * torch.log(v_n)
        exp_p = exps[ids_p].view(-1, 1)/v_p
        exp_n = exps[ids_n].view(1, -1)/v_n
        # (P, N) hinge margins via broadcasting; v holds detached values, so margin has no grad.
        margin = torch.clip(self.threshold - (log_p - log_n), min=0)
        loss = (margin**2).detach()
        derived_loss = 2 * self.tau * margin * (exp_n - exp_p)
        s1_p = self.s1[pos_ids]    # (P, 1)
        p1 = (loss > s1_p).float()
        p2 = (self.u[pos_ids] > self.s2).float()
        tp_loss = (derived_loss/self.rate*p1).mean(dim=-1, keepdim=True)
        tp_loss = (tp_loss/self.rate*p2).mean()

        comp_u = s1_p + torch.clip(loss - s1_p, min=0)/self.rate
        comp_u = comp_u.mean(dim=-1, keepdim=True)
        momentum_comp_u = 0.0
        if self.momentum > 0:
            momentum_comp_u = self.momentum * (comp_u - self.comp_u[pos_ids])
            self.comp_u[pos_ids] = comp_u
        self.u[pos_ids] = (1-self.gamma_2) * self.u[pos_ids] + self.gamma_2 * comp_u + momentum_comp_u
        # NOTE(review): len(ids_p) is the batch size B (mask length), not the
        # number of positives; confirm this scaling is intended.
        self.s1[pos_ids] -= self.eta_1/len(ids_p) * p2/self.rate * (1-p1.mean(dim=-1, keepdim=True)/self.rate)
        self.s2 -= self.eta_2 * (1-p2.mean()/self.rate)
        return tp_loss



class TPMIL_attention_loss(nn.Module):
    '''Pairwise squared-hinge MIL loss with attention pooling and adaptive thresholds.

    Per-bag moving averages ``vn`` and ``vd`` of the detached ``vn``/``vd``
    inputs are kept, and each bag is scored as ``sigmoid(vn_avg / vd_avg)``.
    For every (positive, negative) pair in the batch the squared hinge
    ``max(0, threshold - (score_p - score_n)) ** 2`` is formed. Pairs are
    selected where that loss exceeds the positive's threshold ``s1``
    (indicator ``p1``). Positives are selected where ``u`` exceeds the global
    threshold ``s2`` (indicator ``p2``).

    The returned value is a surrogate: its gradient w.r.t. ``vn``/``vd`` is the
    selected, ``1 / rate``-scaled pairwise hinge gradient (chain rule through
    the sigmoid of the ratio, with the averages held fixed). Its forward value
    is not the hinge loss itself. Afterwards ``u``, ``s1`` and ``s2`` are
    updated in place.

    NOTE(review): ``vn``/``vd`` are presumably the numerator and denominator of
    an attention-weighted average of instance scores; confirm with the caller.
    '''

    def __init__(self, data_length, threshold=0.5, gamma=0.9, eta=1e-1, rate=0.5, momentum=0.0, device=None):
        '''
        :param data_length: number of bags; size of the per-bag state tables.
        :param threshold: margin of the squared hinge on pooled-score differences.
        :param gamma: moving-average weight, used as ``gamma_1`` (for ``vn``/``vd``) and ``gamma_2`` (for ``u``).
        :param eta: step size, used as ``eta_1`` (for ``s1``) and ``eta_2`` (for ``s2``).
        :param rate: fraction appearing in the ``/ rate`` scalings.
        :param momentum: if > 0, adds ``momentum * (new - previous)`` to the ``vn``/``vd``/``u`` updates.
        :param device: device for the state tensors; falsy means CUDA when available, else CPU.
        '''
        super(TPMIL_attention_loss, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.gamma_1 = gamma
        self.eta_1 = eta
        self.gamma_2 = gamma
        self.eta_2 = eta
        self.data_length = data_length
        self.rate = rate
        # Per-bag moving averages of the numerator / denominator inputs, shape (data_length, 1).
        self.vn = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
        self.vd = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
        # Per-bag moving average of comp_u; only rows of positive bags are updated.
        self.u = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
        # Per-bag threshold on pairwise losses; only rows of positive bags are updated.
        self.s1 = torch.tensor([0.0]*data_length).view(-1, 1).to(self.device)
        # Single threshold on u, shape (1, 1).
        self.s2 = torch.tensor([0.0]).view(-1, 1).to(self.device)
        self.threshold = threshold
        self.momentum = momentum
        if momentum > 0.0:
            # Previous vn / vd / comp_u values, used to form the momentum correction.
            self.ovn = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
            self.ovd = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)
            self.comp_u = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device)

    def update_smoothing(self, decay_factor):
        '''Divide both moving-average weights and both step sizes by ``decay_factor``.'''
        self.gamma_1 = self.gamma_1/decay_factor
        self.eta_1 = self.eta_1/decay_factor
        self.gamma_2 = self.gamma_2/decay_factor
        self.eta_2 = self.eta_2/decay_factor

    def forward(self, vn, vd, y_true, ids):
        '''Update the per-bag state for ``ids`` and return the surrogate loss.

        :param vn: per-bag numerator values with B elements, e.g. shape (B,) or (B, 1).
        :param vd: per-bag denominator values with B elements.
        :param y_true: bag labels with B elements; 1 = positive, 0 = negative
            (any other value is ignored).
        :param ids: B indices in ``[0, data_length)`` of the bags in this batch.
        :return: scalar tensor. When the batch lacks positives or negatives
            there are no pairs: only ``vn``/``vd`` are updated, and a zero loss
            still connected to the inputs is returned. Previously this case
            wrote NaN into ``s2`` (no positives) or into ``u``/``s1`` (no
            negatives), permanently.
        '''
        # Normalise layouts so (B,) and (B, 1) inputs both line up with the
        # (data_length, 1) state tables instead of broadcasting to (B, B).
        vn = vn.reshape(-1, 1)
        vd = vd.reshape(-1, 1)
        ids = ids.reshape(-1).to(self.device)
        momentum_vn = 0.0
        momentum_vd = 0.0
        if self.momentum > 0:
            momentum_vn = self.momentum * (vn.detach() - self.ovn[ids])
            self.ovn[ids] = vn.detach()
            momentum_vd = self.momentum * (vd.detach() - self.ovd[ids])
            self.ovd[ids] = vd.detach()
        self.vn[ids] = (1-self.gamma_1) * self.vn[ids] + self.gamma_1 * vn.detach() + momentum_vn
        self.vd[ids] = (1-self.gamma_1) * self.vd[ids] + self.gamma_1 * vd.detach() + momentum_vd
        vsn = self.vn[ids]
        # Clamp keeps the ratio finite if the denominator average reaches 0.
        vsd = torch.clamp(self.vd[ids], min=1e-8)
        vsf = torch.sigmoid(vsn / vsd)
        # Sigmoid derivative times the quotient-rule term (averages frozen; grad flows via vn, vd).
        gvsf = vsf * (1-vsf) * (vn/vsd - vsn/(vsd**2)*vd)
        # reshape(-1) rather than squeeze(): squeeze makes a 0-d mask when B == 1.
        labels = y_true.reshape(-1)
        ids_p = (labels == 1)
        ids_n = (labels == 0)
        if not ids_p.any() or not ids_n.any():
            # No (positive, negative) pairs: means below would be over empty
            # tensors and would poison u / s1 / s2 with NaN.
            return (vn.sum() + vd.sum()) * 0.0
        pos_ids = ids[ids_p]
        v_p = vsf[ids_p].view(-1, 1)    # (P, 1)
        v_n = vsf[ids_n].view(1, -1)    # (1, N)
        g_p = gvsf[ids_p].view(-1, 1)
        g_n = gvsf[ids_n].view(1, -1)
        # (P, N) hinge margins via broadcasting; the averages are detached, so margin has no grad.
        margin = torch.clip(self.threshold - (v_p - v_n), min=0)
        loss = (margin**2).detach()
        derived_loss = 2 * margin * (g_n - g_p)
        s1_p = self.s1[pos_ids]    # (P, 1)
        p1 = (loss > s1_p).float()
        p2 = (self.u[pos_ids] > self.s2).float()
        tp_loss = (derived_loss/self.rate*p1).mean(dim=-1, keepdim=True)
        tp_loss = (tp_loss/self.rate*p2).mean()

        comp_u = s1_p + torch.clip(loss - s1_p, min=0)/self.rate
        comp_u = comp_u.mean(dim=-1, keepdim=True)
        momentum_comp_u = 0.0
        if self.momentum > 0:
            momentum_comp_u = self.momentum * (comp_u - self.comp_u[pos_ids])
            self.comp_u[pos_ids] = comp_u
        self.u[pos_ids] = (1-self.gamma_2) * self.u[pos_ids] + self.gamma_2 * comp_u + momentum_comp_u
        # NOTE(review): len(ids_p) is the batch size B (mask length), not the
        # number of positives; confirm this scaling is intended.
        self.s1[pos_ids] -= self.eta_1/len(ids_p) * p2/self.rate * (1-p1.mean(dim=-1, keepdim=True)/self.rate)
        self.s2 -= self.eta_2 * (1-p2.mean()/self.rate)
        return tp_loss
        

# NOTE(review): the module-level string literal below is inert. It is never
# executed, and the two ``TPMIL_softmax_loss`` classes written inside it are
# NOT the ``TPMIL_softmax_loss`` class that this module actually defines. It
# appears to be an archive of earlier variants (the first one smooths
# tau * log(exps) directly; the second one updates ``v`` after computing the
# loss). Consider relying on version-control history and deleting it.
'''


class TPMIL_softmax_loss(nn.Module):
    def __init__(self, data_length, threshold=0.5, tau=0.1, gamma=0.9, eta=1e-1, rate=0.5, device=None):
        super(TPMIL_softmax_loss, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device   
        self.gamma_1 = gamma
        self.eta_1 = eta
        self.gamma_2 = gamma
        self.eta_2 = eta
        self.data_length = data_length
        self.rate = rate
        self.tau = tau
        self.v = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device) 
        self.u = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device) 
        self.s1 = torch.tensor([0.0]*data_length).view(-1, 1).to(self.device) 
        self.s2 = torch.tensor([0.0]).view(-1, 1).to(self.device) 
        self.threshold = threshold

    def update_smoothing(self, decay_factor):
        self.gamma_1 = self.gamma_1/decay_factor
        self.eta_1 = self.eta_1/decay_factor
        self.gamma_2 = self.gamma_2/decay_factor
        self.eta_2 = self.eta_2/decay_factor

    def forward(self, exps, y_true, ids): 
        ids = ids.to(self.device) 
        vv = self.v[ids]
        ids_p = (y_true.squeeze() == 1)
        ids_n = (y_true.squeeze() == 0)
        log_p = vv[ids_p].view(-1,1)
        log_n = vv[ids_n].view(1,-1)
        mat_log_n = log_n.repeat(len(log_p), 1)
        log_exps = self.tau * torch.log(exps)
        log_exp_p = log_exps[ids_p].view(-1,1)
        log_exp_n = log_exps[ids_n].view(1,-1)
        mat_log_exp_n = log_exp_n.repeat(len(log_exp_p), 1)
        loss = (torch.clip(self.threshold - (log_p - mat_log_n), min=0)**2).detach()
        derived_loss = 2 * torch.clip(self.threshold - (log_p - mat_log_n), min=0) * (mat_log_exp_n - log_exp_p)
        p1 = (loss > self.s1[ids[ids_p]]).float()
        #print(p1)
        p2 = (self.u[ids[ids_p]] > self.s2).float()
        #print(p2)
        tp_loss = (derived_loss/self.rate*p1).mean(dim=-1,keepdim=True)
        tp_loss = (tp_loss/self.rate*p2).mean()
        
        
        
        self.v[ids] = (1-self.gamma_1) * self.v[ids] + self.gamma_1 * log_exps.detach()
        comp_u = self.s1[ids[ids_p]] + torch.clip(loss.detach()-self.s1[ids[ids_p]], min=0)/self.rate
        comp_u = comp_u.mean(dim=-1, keepdim=True)
        self.u[ids[ids_p]] = (1-self.gamma_2) * self.u[ids[ids_p]] + self.gamma_2 * comp_u
        self.s1[ids[ids_p]] -= self.eta_1/len(ids_p) *p2/self.rate* (1-p1.mean(dim=-1, keepdim=True)/self.rate)
        self.s2 -= self.eta_2 * (1-p2.mean()/self.rate)
        
        
        return tp_loss


class TPMIL_softmax_loss(nn.Module):
    def __init__(self, data_length, threshold=0.5, tau=0.1, gamma=0.9, eta=1e-1, rate=0.5, device=None):
        super(TPMIL_softmax_loss, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device   
        self.gamma_1 = gamma
        self.eta_1 = eta
        self.gamma_2 = gamma
        self.eta_2 = eta
        self.data_length = data_length
        self.rate = rate
        self.tau = tau
        self.v = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device) 
        self.u = torch.tensor([1.0]*data_length).view(-1, 1).to(self.device) 
        self.s1 = torch.tensor([0.0]*data_length).view(-1, 1).to(self.device) 
        self.s2 = torch.tensor([0.0]).view(-1, 1).to(self.device) 
        self.threshold = threshold

    def update_smoothing(self, decay_factor):
        self.gamma_1 = self.gamma_1/decay_factor
        self.eta_1 = self.eta_1/decay_factor
        self.gamma_2 = self.gamma_2/decay_factor
        self.eta_2 = self.eta_2/decay_factor

    def forward(self, exps, y_true, ids): 
        ids = ids.to(self.device) 
        vv = self.v[ids]
        ids_p = (y_true.squeeze() == 1)
        ids_n = (y_true.squeeze() == 0)
        v_p = vv[ids_p].view(-1,1)
        v_n = vv[ids_n].view(1,-1)
        log_p = self.tau * torch.log(v_p)
        log_n = self.tau * torch.log(v_n)
        mat_log_n = log_n.repeat(len(log_p), 1)
        exp_p = exps[ids_p].view(-1,1)/v_p
        exp_n = exps[ids_n].view(1,-1)/v_n
        #print(torch.cat([exps[ids_p].view(-1,1), exps[ids_n].view(-1,1)],dim=-1))
        mat_exps = exp_n.repeat(len(exp_p), 1)
        loss = (torch.clip(self.threshold - (log_p - mat_log_n), min=0)**2).detach()
        derived_loss = 2 *self.tau* torch.clip(self.threshold - (log_p - mat_log_n), min=0) * (mat_exps - exp_p)
        #derived_loss = 2 *self.tau* torch.clip(self.threshold - (log_p - mat_log_n), min=0) * (torch.log(mat_exps) - torch.log(exp_p))
        #derived_loss = (torch.clip(self.threshold - self.tau*(torch.log(exp_p) - torch.log(mat_exps)), min=0)**2).mean()
        p1 = (loss > self.s1[ids[ids_p]]).float()
        #print(p1)
        p2 = (self.u[ids[ids_p]] > self.s2).float()
        #print(p2)
        tp_loss = (derived_loss/self.rate*p1).mean(dim=-1,keepdim=True)
        tp_loss = (tp_loss/self.rate*p2).mean()
        #tp_loss = (torch.clip(self.threshold - self.tau*(torch.log(exp_p) - torch.log(mat_exps)), min=0)**2).mean()
        #print(tp_loss)
        
        
        
        self.v[ids] = (1-self.gamma_1) * self.v[ids] + self.gamma_1 * exps.detach()
        #print(self.v[ids])
        comp_u = self.s1[ids[ids_p]] + torch.clip(loss.detach()-self.s1[ids[ids_p]], min=0)/self.rate
        comp_u = comp_u.mean(dim=-1, keepdim=True)
        self.u[ids[ids_p]] = (1-self.gamma_2) * self.u[ids[ids_p]] + self.gamma_2 * comp_u
        #print(self.u[ids[ids_p]])
        self.s1[ids[ids_p]] -= self.eta_1/len(ids_p) *p2/self.rate* (1-p1.mean(dim=-1, keepdim=True)/self.rate)
        self.s2 -= self.eta_2 * (1-p2.mean()/self.rate)
        
         
        return tp_loss
'''
